import matplotlib.pyplot as plt


def draw_pic(data: dict, max_num=1):
    """Plot each series in *data* as a line, capping values at *max_num*.

    Args:
        data: Mapping of series label -> sequence of values. Each series is
            drawn against its own 0-based sample index, and the label is
            used for the legend entry.
        max_num: Ceiling applied to the plotted values; any value greater
            than this is drawn as ``max_num`` so outliers do not stretch
            the y-axis.

    The chart is drawn on a new figure with a grid and a legend, then
    displayed via ``plt.show()``. Nothing is returned.
    """
    # Start a fresh figure so the curves do not land on an existing one.
    plt.figure()

    for label, series in data.items():
        sample_idx = list(range(len(series)))
        capped = [max_num if value > max_num else value for value in series]
        plt.plot(sample_idx, capped, label=label)

    plt.grid()
    plt.legend()
    plt.show()

if __name__ == "__main__":
    # Script entry point: collect the loss values reported in the training
    # log and plot one curve per loss term.
    log_path = "logs/train_2.log"

    # One list of samples per loss name; only these names are extracted.
    res = {
        "loss": [],
        "loss_classifier": [],
        "loss_box_reg": [],
        "loss_objectness": [],
        "loss_rpn_box_reg": [],
    }

    # Explicit encoding keeps decoding independent of the platform locale;
    # errors="replace" stops stray non-UTF-8 bytes elsewhere in the log from
    # aborting the run.
    with open(log_path, "r", encoding="utf-8", errors="replace") as f:
        # Stream the file line by line instead of loading it whole.
        for line in f:
            # Only lines containing an " eta: " field are parsed.
            if " eta: " not in line:
                continue
            # split() without arguments collapses repeated whitespace, so no
            # empty tokens end up between a "name:" token and its value.
            tokens = line.split()
            # Walk adjacent (name, value) pairs. Pairing via zip means a
            # "name:" token at the very end of a line has no partner and is
            # skipped, instead of indexing past the end of the list.
            for name, value in zip(tokens, tokens[1:]):
                # Require the trailing colon so look-alike tokens (e.g.
                # "losss") are not mistaken for a tracked key.
                if not name.endswith(":"):
                    continue
                key = name[:-1]
                if key in res:
                    # A non-numeric value still raises ValueError on purpose:
                    # silently dropping samples would distort the curves.
                    res[key].append(float(value))

    draw_pic(res)
